import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import scipy.io as sio
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
import platform
from typing import List, Union
from torch.nn import Parameter

from argparse import ArgumentParser

# Command-line interface of the ISTA-SAM-Net fine-tuning script.
parser = ArgumentParser(description='ISTA-SAM-Net')

# --- schedule and model hyper-parameters ---
parser.add_argument(
    '--tune_epoch', type=int, default=40,
    help='epoch number of tuning',
)
parser.add_argument(
    '--load_epoch', type=int, default=200,
    help='epoch number of pre_trained model',
)
parser.add_argument(
    '--layer_num', type=int, default=9,
    help='phase number of ISTA-Net-plus',
)
parser.add_argument(
    '--learning_rate', type=float, default=1e-4,
    help='learning rate',
)
parser.add_argument(
    '--group_num', type=int, default=1,
    help='group number for training',
)
parser.add_argument(
    '--cs_ratio', type=int, default=4,
    help='from {1, 4, 10, 25, 40, 50}',
)
parser.add_argument(
    '--gpu_list', type=str, default='0',
    help='gpu index',
)
# Boolean switch: absent -> False, present -> True.
parser.add_argument('--sam', default=False, action='store_true')

# --- input / output locations ---
parser.add_argument(
    '--matrix_dir', type=str, default='sampling_matrix',
    help='sampling matrix directory',
)
parser.add_argument(
    '--model_dir', type=str, default='model',
    help='trained or pre-trained model directory',
)
parser.add_argument(
    '--data_dir', type=str, default='data',
    help='training data directory',
)
parser.add_argument(
    '--log_dir', type=str, default='log',
    help='log directory',
)

# Turn TF32 off for CUDA matrix multiplications so they are computed in
# regular FP32 instead of the reduced-precision TF32 format.
# NOTE(review): the library default for this flag has differed between
# PyTorch releases — confirm against the installed version.
torch.backends.cuda.matmul.allow_tf32 = False
# Turn TF32 off for cuDNN operations (e.g. convolutions) as well.
torch.backends.cudnn.allow_tf32 = False


class GradientModule(nn.Module):
    """Scaled data-fidelity gradient with a learnable step size.

    Holds a single learnable scalar ``lambda_step`` (initialised to 0.5) and
    returns ``lambda_step * (x @ PhiTPhi - PhiTb)``.
    """

    def __init__(self):
        super().__init__()
        initial_step = torch.tensor([0.5])
        self.lambda_step = nn.Parameter(initial_step)

    def forward(self, x, PhiTPhi, PhiTb):
        """Return the step-scaled residual ``x @ PhiTPhi - PhiTb``.

        All three arguments are 2-D tensors; ``x`` and ``PhiTb`` must have the
        same shape and ``PhiTPhi`` must be multipliable from the left by ``x``.
        """
        residual = x.mm(PhiTPhi) - PhiTb
        return self.lambda_step * residual

class Denoiser(torch.nn.Module):
    """Learned soft-thresholding step operating on 33x33 image blocks.

    The input is reshaped to ``(-1, 1, 33, 33)``, passed through a two-layer
    forward transform, soft-thresholded with the learnable ``soft_thr``, and
    mapped back with a two-layer backward transform.
    """

    def __init__(self):
        super().__init__()
        self.soft_thr = nn.Parameter(torch.Tensor([0.01]))

        # Kernels are created in the same order as the parameter names below
        # so that random initialisation and state-dict layout stay fixed.
        self.conv1_forward = self._xavier_kernel(32, 1)
        self.conv2_forward = self._xavier_kernel(32, 32)
        self.conv1_backward = self._xavier_kernel(32, 32)
        self.conv2_backward = self._xavier_kernel(1, 32)

    @staticmethod
    def _xavier_kernel(out_channels, in_channels):
        """Build a Xavier-normal initialised 3x3 convolution kernel parameter."""
        weight = torch.Tensor(out_channels, in_channels, 3, 3)
        return nn.Parameter(init.xavier_normal_(weight))

    def _backward_transform(self, features):
        """Apply conv1_backward -> ReLU -> conv2_backward to ``features``."""
        hidden = F.relu(F.conv2d(features, self.conv1_backward, padding=1))
        return F.conv2d(hidden, self.conv2_backward, padding=1)

    def forward(self, x):
        """Denoise ``x`` and return ``[x_pred, symloss]``.

        ``x_pred`` is the thresholded reconstruction flattened to
        ``(-1, 1089)``. ``symloss`` is the backward transform applied to the
        un-thresholded features minus the reshaped input, with shape
        ``(-1, 1, 33, 33)``.
        """
        x_input = x.view(-1, 1, 33, 33)

        hidden = F.relu(F.conv2d(x_input, self.conv1_forward, padding=1))
        x_forward = F.conv2d(hidden, self.conv2_forward, padding=1)

        # Soft thresholding: sign(z) * max(|z| - thr, 0).
        magnitude = F.relu(torch.abs(x_forward) - self.soft_thr)
        thresholded = torch.mul(torch.sign(x_forward), magnitude)

        x_pred = self._backward_transform(thresholded).view(-1, 1089)
        symloss = self._backward_transform(x_forward) - x_input

        return [x_pred, symloss]


class SharpModule(nn.Module):
    """Compute a perturbation from a sub-gradient with learnable scaling.

    The perturbation is ``rho * ((1 - gamma) + gamma / ||g||) * g`` where the
    norm is taken per row (``dim=1``) and treated as a constant for autograd.
    Rows for which ``gamma / ||g||`` is not finite (e.g. zero norm) only use
    the ``(1 - gamma)`` term.

    Args:
        rho: initial value of the learnable perturbation radius.
    """

    def __init__(self, rho=1e-3):
        super().__init__()
        self.rho = nn.Parameter(torch.tensor([rho]))
        self.gamma = nn.Parameter(torch.tensor([0.9]))

    def forward(self, sub_grad):
        """Return the perturbation for ``sub_grad``.

        Args:
            sub_grad: tensor with at least two dimensions, or ``None``.

        Returns:
            ``0.`` when ``sub_grad`` is ``None``; otherwise a tensor with the
            same shape as ``sub_grad``.
        """
        if sub_grad is None:
            return 0.

        norm = torch.sqrt(torch.sum(torch.square(sub_grad), dim=1, keepdim=True)).detach()
        alpha = 1. - self.gamma

        # Decide which rows have a usable inverse norm without recording the
        # division in the autograd graph.
        with torch.no_grad():
            usable = torch.isfinite(self.gamma / norm)

        # Divide by a harmless denominator on the masked rows. Masking only the
        # *result* with torch.where still back-propagates 0 / 0 = NaN into
        # gamma for zero-norm rows; masking the denominator avoids that while
        # leaving the forward values untouched.
        safe_norm = torch.where(usable, norm, torch.ones_like(norm))
        beta = torch.where(usable, self.gamma / safe_norm, torch.zeros_like(norm))

        epsilon = self.rho * (alpha + beta) * sub_grad
        return epsilon


class SubGradientModule(nn.Module):
    """Return the difference between a denoiser's input and output.

    The gradient module handed to the constructor is stored as the
    ``gradient_module`` submodule but is not used by ``forward``.
    """

    def __init__(self, gm: GradientModule):
        super().__init__()
        self.gradient_module = gm

    def forward(self, u, v, PhiTPhi, PhiTb):
        """Return ``u - v``; ``PhiTPhi`` and ``PhiTb`` are accepted but ignored."""
        return u - v


# Define ISTA-Net
class ISTANet(torch.nn.Module):
    """Unrolled ISTA network: ``LayerNo`` gradient + denoiser phases."""

    def __init__(self, LayerNo):
        super(ISTANet, self).__init__()
        self.LayerNo = LayerNo
        # All gradient modules are built before any denoiser so that the
        # random initialisation order matches the state-dict layout.
        gradient_steps = [GradientModule() for _ in range(LayerNo)]
        self.gradient_modules = nn.ModuleList(gradient_steps)
        denoising_steps = [Denoiser() for _ in range(LayerNo)]
        self.denoisers = nn.ModuleList(denoising_steps)

    def layer(self, x, PhiTPhi, PhiTb, t):
        """Run phase ``t``: gradient step followed by the phase's denoiser.

        Returns the updated estimate and that phase's symmetry residual.
        """
        stepped = x - self.gradient_modules[t](x, PhiTPhi, PhiTb)
        denoised, layer_sym = self.denoisers[t](stepped)
        return denoised, layer_sym

    def forward(self, Phix, Phi, Qinit):
        """Reconstruct from measurements ``Phix``.

        Args:
            Phix: measurements, multiplied as ``Phix @ Phi`` and ``Phix @ Qinit^T``.
            Phi: sampling matrix.
            Qinit: matrix giving the initial estimate ``Phix @ Qinit^T``.

        Returns:
            ``[x_final, layers_sym]`` where ``layers_sym`` holds one symmetry
            residual per phase.
        """
        PhiTPhi = Phi.transpose(0, 1).mm(Phi)
        PhiTb = Phix.mm(Phi)

        estimate = Phix.mm(Qinit.transpose(0, 1))

        layers_sym = []   # for computing symmetric loss
        for phase in range(self.LayerNo):
            estimate, phase_sym = self.layer(estimate, PhiTPhi, PhiTb, phase)
            layers_sym.append(phase_sym)

        return [estimate, layers_sym]


# Define ISTA-SAM-Net
class ISTASamNet(torch.nn.Module):
    """Unrolled ISTA network with a per-phase perturbation (``SharpModule``).

    Each phase perturbs the current estimate by ``epsilon`` computed from the
    previous phase's sub-gradient, applies the gradient step and denoiser at
    the perturbed point, and then removes ``epsilon`` again.
    """

    def __init__(self, LayerNo):
        super().__init__()
        self.LayerNo = LayerNo
        # Construction order is kept fixed: it determines both the random
        # initialisation sequence and the parameter naming/order.
        gradient_steps = [GradientModule() for _ in range(LayerNo)]
        self.gradient_modules = nn.ModuleList(gradient_steps)
        denoising_steps = [Denoiser() for _ in range(LayerNo)]
        self.denoisers = nn.ModuleList(denoising_steps)
        sub_gradient_steps = [
            SubGradientModule(self.gradient_modules[t]) for t in range(LayerNo)
        ]
        self.sub_gradient_modules = nn.ModuleList(sub_gradient_steps)
        sharp_steps = [SharpModule() for _ in range(LayerNo)]
        self.sam_modules = nn.ModuleList(sharp_steps)

    def layer(self, sg, x, PhiTPhi, PhiTb, t):
        """Run phase ``t``.

        Args:
            sg: sub-gradient from the previous phase, or ``None`` for the first.
            x: current estimate.
            PhiTPhi, PhiTb: precomputed products used by the gradient step.
            t: phase index.

        Returns:
            ``(sg, x, layer_sym)``: the new sub-gradient, the updated estimate
            and the phase's symmetry residual.
        """
        epsilon = self.sam_modules[t](sg)
        perturbed = x + epsilon
        u = perturbed - self.gradient_modules[t](perturbed, PhiTPhi, PhiTb)
        v, layer_sym = self.denoisers[t](u)
        next_sg = self.sub_gradient_modules[t](u, v, PhiTPhi, PhiTb)
        return next_sg, v - epsilon, layer_sym

    def forward(self, Phix, Phi, Qinit):
        """Reconstruct from measurements ``Phix``.

        Returns ``[x_final, layers_sym]`` where ``layers_sym`` holds one
        symmetry residual per phase.
        """
        PhiTPhi = Phi.transpose(0, 1).mm(Phi)
        PhiTb = Phix.mm(Phi)

        estimate = Phix.mm(Qinit.transpose(0, 1))

        layers_sym = []   # for computing symmetric loss
        sub_grad = None   # no sub-gradient is available before the first phase
        for phase in range(self.LayerNo):
            sub_grad, estimate, phase_sym = self.layer(
                sub_grad, estimate, PhiTPhi, PhiTb, phase
            )
            layers_sym.append(phase_sym)

        return [estimate, layers_sym]


class RandomDataset(Dataset):
    """Row-indexed dataset over a 2-D array.

    ``length`` is reported as the dataset size independently of how many rows
    ``data`` actually has.
    """

    def __init__(self, data, length):
        self.data, self.len = data, length

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        """Return row ``index`` of the data as a float32 tensor."""
        row = self.data[index, :]
        return torch.Tensor(row).float()


# %-style templates used by main().
# MODEL_DIR_PATTERN fields, in order: base model directory, model name,
# layer count, group number, CS ratio, learning rate (4 decimals).
MODEL_DIR_PATTERN = "./%s/CS_%s_layer_%d_group_%d_ratio_%d_lr_%.4f"
# LOG_DIR_PATTERN fields, in order: log directory, model name, layer count,
# group number, CS ratio, learning rate (4 decimals). Yields a .txt file path.
LOG_DIR_PATTERN = "./%s/Log_CS_%s_layer_%d_group_%d_ratio_%d_lr_%.4f.txt"
# OUTPUT_PATTERN fields, in order: epoch index, total epochs, total loss,
# discrepancy loss, constraint loss.
OUTPUT_PATTERN = "[%02d/%02d] Total Loss: %.4f, Discrepancy Loss: %.4f,  Constraint Loss: %.4f"


def load_net(LayerNo, path: str) -> ISTANet:
    """Build an ``ISTANet`` with ``LayerNo`` phases and load weights from ``path``.

    Args:
        LayerNo: number of phases of the network to construct.
        path: file containing a state dict saved with ``torch.save``.

    Returns:
        The network with the loaded parameters, on the CPU.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing file
        or a state dict that does not match the architecture.
    """
    model = ISTANet(LayerNo)
    # The freshly built model lives on the CPU, so map the checkpoint there
    # too. Without this a checkpoint saved from a CUDA device cannot be read
    # on a CPU-only host.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    return model


def copy_weights(net: "ISTANet", sam_net: "ISTASamNet"):
    """Copy every parameter of ``net`` into the same-named parameter of ``sam_net``.

    Parameters of ``sam_net`` that have no counterpart in ``net`` are left
    untouched. Values are copied in place, so ``sam_net`` keeps its own
    parameter objects and their ``requires_grad`` flags.

    Args:
        net: source network.
        sam_net: destination network, modified in place.
    """
    # One dict lookup per parameter instead of scanning a list of names.
    source_params = dict(net.named_parameters())
    # In-place writes to leaf parameters must not be tracked by autograd.
    with torch.no_grad():
        for name, target in sam_net.named_parameters():
            source = source_params.get(name)
            if source is not None:
                target.copy_(source)

def main(
    tune_epoch=40, load_epoch=200,
    layer_num=9, learning_rate=1e-4, group_num=1,
    cs_ratio=25, gpu_id=0, sam=False,
    matrix_dir='sampling_matrix',
    model_dir='model',
    data_dir='data',
    log_dir='log'
):
    """Fine-tune an ISTA-SAM-Net initialised from a pre-trained ISTA-Net.

    Loads the sampling matrix and training blocks from ``.mat`` files, loads
    (or computes and caches) the initialisation matrix, copies the pre-trained
    ISTA-Net weights into a new ``ISTASamNet``, and trains it with Adam.

    Args:
        tune_epoch: number of fine-tuning epochs.
        load_epoch: epoch number of the pre-trained checkpoint to load.
        layer_num: number of phases in both networks.
        learning_rate: Adam learning rate; also part of the directory names.
        group_num: group number used in the output directory / log names.
            The pre-trained checkpoint is always read from group 1.
        cs_ratio: compressive-sensing ratio selecting the matrix files.
        gpu_id: CUDA device index (int or numeric string); ignored without CUDA.
        sam: currently unused.
        matrix_dir: directory with the sampling / initialisation matrices.
        model_dir: base directory for pre-trained and fine-tuned checkpoints.
        data_dir: directory containing ``Training_Data.mat``.
        log_dir: directory for the training log.
    """
    # ratio_dict = {1: 10, 4: 43, 10: 109, 25: 272, 30: 327, 40: 436, 50: 545}
    # n_input = ratio_dict[cs_ratio]
    # n_output = 1089

    use_cuda = torch.cuda.is_available()
    device = f'cuda:{gpu_id}' if use_cuda else 'cpu'
    if use_cuda:
        # Selecting a CUDA device on a CPU-only host would raise.
        torch.cuda.set_device(device)

    nrtrain = 88912   # number of training blocks
    batch_size = 64

    # Load CS Sampling Matrix: phi
    Phi_data_Name = './%s/phi_0_%d_1089.mat' % (matrix_dir, cs_ratio)
    Phi_data = sio.loadmat(Phi_data_Name)
    Phi_input = Phi_data['phi']

    Training_data_Name = 'Training_Data.mat'
    Training_data = sio.loadmat('./%s/%s' % (data_dir, Training_data_Name))
    Training_labels = Training_data['labels']

    Qinit_Name = './%s/Initialization_Matrix_%d.mat' % (matrix_dir, cs_ratio)

    # Initialization matrix: reuse the cached file, otherwise compute the
    # least-squares map X Y^T (Y Y^T)^-1 from the training data and cache it.
    if os.path.exists(Qinit_Name):
        Qinit_data = sio.loadmat(Qinit_Name)
        Qinit = Qinit_data['Qinit']
    else:
        X_data = Training_labels.transpose()
        Y_data = np.dot(Phi_input, X_data)
        Y_YT = np.dot(Y_data, Y_data.transpose())
        X_YT = np.dot(X_data, Y_data.transpose())
        Qinit = np.dot(X_YT, np.linalg.inv(Y_YT))
        del X_data, Y_data, X_YT, Y_YT
        sio.savemat(Qinit_Name, {'Qinit': Qinit})

    model = ISTASamNet(layer_num)
    model_name = "ISTA_SAM_NET"

    # Warm-start from the pre-trained ISTA-Net checkpoint (group 1).
    pretrained_path = f"{MODEL_DIR_PATTERN}/net_params_%d.pkl" % (
        model_dir, "ISTA_NET", layer_num,
        1, cs_ratio, learning_rate,
        load_epoch
    )
    copy_weights(load_net(layer_num, pretrained_path), model)

    model = model.to(device)

    # Worker processes are disabled on Windows; other platforms use four.
    num_workers = 0 if platform.system() == "Windows" else 4
    rand_loader = DataLoader(
        dataset=RandomDataset(Training_labels, nrtrain),
        batch_size=batch_size, num_workers=num_workers, shuffle=True
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model_dir = MODEL_DIR_PATTERN % (model_dir, model_name, layer_num,
                                     group_num, cs_ratio, learning_rate)
    log_file_name = LOG_DIR_PATTERN % (log_dir, model_name, layer_num,
                                       group_num, cs_ratio, learning_rate)
    os.makedirs(model_dir, exist_ok=True)
    # The log is opened in append mode below, which needs its directory.
    os.makedirs(os.path.dirname(log_file_name), exist_ok=True)

    Phi = torch.from_numpy(Phi_input).type(torch.FloatTensor).to(device)
    Qinit = torch.from_numpy(Qinit).type(torch.FloatTensor).to(device)

    # Weight of the symmetry-constraint term in the total loss.
    constraint_weight = 0.01

    # Training loop
    for epoch_i in range(tune_epoch):
        output_data = None
        for batch_x in rand_loader:
            batch_x = batch_x.to(device)

            Phix = torch.mm(batch_x, torch.transpose(Phi, 0, 1))

            [x_output, loss_layers_sym] = model(Phix, Phi, Qinit)

            loss_discrepancy = torch.mean(torch.pow(x_output - batch_x, 2))
            loss_constraint = sum(
                torch.mean(torch.pow(layer_sym, 2))
                for layer_sym in loss_layers_sym
            )
            loss_all = loss_discrepancy + constraint_weight * loss_constraint

            # Zero gradients, perform a backward pass, and update the weights.
            optimizer.zero_grad()
            loss_all.backward()
            optimizer.step()

            output_data = OUTPUT_PATTERN % (epoch_i, tune_epoch,
                                            loss_all.item(),
                                            loss_discrepancy.item(),
                                            loss_constraint.item())
            print(output_data)

        # Log the last batch of the epoch; nothing to log for an empty loader.
        if output_data is not None:
            with open(log_file_name, 'a') as output_file:
                output_file.write(output_data + '\n')

        # Checkpoint every fifth epoch, and always after the final epoch so
        # the fully tuned weights are not lost.
        is_last_epoch = epoch_i == tune_epoch - 1
        if epoch_i % 5 == 0 or is_last_epoch:
            torch.save(model.state_dict(), "./%s/net_params_tune_%d.pkl" % (
                model_dir, epoch_i
            ))


if __name__ == '__main__':
    # Forward every parsed option by keyword, including the directory
    # options, which would otherwise be parsed but silently ignored.
    args = parser.parse_args()
    main(
        tune_epoch=args.tune_epoch,
        load_epoch=args.load_epoch,
        layer_num=args.layer_num,
        learning_rate=args.learning_rate,
        group_num=args.group_num,
        cs_ratio=args.cs_ratio,
        gpu_id=args.gpu_list,
        sam=args.sam,
        matrix_dir=args.matrix_dir,
        model_dir=args.model_dir,
        data_dir=args.data_dir,
        log_dir=args.log_dir,
    )
